"""The purpose of this module is to transfer input of a certain task
(outputs stored on other nodes) before processing can begin on them.
This module will be used to pull input for the Reduce function

"""

import time
import shutil
import socket
import os
import ma.log
import ma.const
import ma.net.brdcstsv
from . import outputserver
from . import constants
import copy


class InputFetcher(object):
    """Fetch the input files of a task from the nodes that host them.

    An InputFetcher broadcasts a request naming the task outputs it
    needs; nodes hosting one of them register with the local output
    server, and each file is then pulled over a TCP connection. Outputs
    already present on this node are simply copied into the job's
    input directory.
    """

    # fetch_type values: classic MapReduce pulls the partition of a map
    # output keyed by the reduce task id; MR+ transfers whole files.
    MAP_REDUCE_TYPE = 0
    MRIMPROV_TYPE = 1


    def __init__(self, output_server, nodetracker, fetch_type=MRIMPROV_TYPE):
        """Initialize the fetcher and start its broadcast packet sender.

        output_server -- records which node answered a broadcast for a
            given output (queried via return_output_host_ip).
        nodetracker -- used to locate task outputs stored locally.
        fetch_type -- MAP_REDUCE_TYPE or MRIMPROV_TYPE (default).
        """
        # logger shared with the rest of the commons package
        self.__log = ma.log.get_logger("ma.commons")

        try:
            self.__broadcast_addr = ma.const.XmlData.get_str_data(ma.const.xml_broadcast_address).strip()
            self.__output_comm_port = ma.const.XmlData.get_int_data(ma.const.xml_net_output_srv_bind_port)

            # sender used to broadcast "who has these outputs?" requests
            self.__brdcster = ma.net.brdcstsv.BroadcastPacketSender(self.__broadcast_addr, self.__output_comm_port)
            self.__log.debug('Init Broadcast packet sender')

            # seconds to wait for replies after each broadcast
            self.__wait_timeperiod = ma.const.XmlData.get_int_data(ma.const.xml_file_req_responce_wait)

            # number of broadcast rounds before giving up on missing outputs
            self.__brdcst_attempts = ma.const.XmlData.get_int_data(ma.const.xml_out_brdcts_rqst_attempts)

            # buffer size (bytes) used when streaming a file over TCP
            self.__file_transfer_buff_size = ma.const.XmlData.get_int_data(ma.const.xml_net_file_transfer_buff_size)

            # server for querying results of the broadcast
            self.__output_server = output_server

            # for querying which task outputs are local
            self.__nodetracker = nodetracker

            self.__fetch_type = fetch_type

            self.__log.info('Initialized the Input Fetcher class')
        except Exception as err_msg:
            # NOTE(review): failures are only logged, leaving the object
            # half-initialized; callers receive no signal that setup failed.
            self.__log.error("Error while creating Input Fetcher object: %s", err_msg)


    def fetch_input(self, outputs_of_tasks, task_id, job_id):
        """Collect every output in outputs_of_tasks into job_id's input
        directory; call this before a reduce task starts processing.

        outputs_of_tasks -- list of (job_id, task_type, out_task_id)
            tuples naming the outputs to collect. Not mutated (a deep
            copy is worked on).
        task_id -- id of the requesting task; in MapReduce mode it
            selects the partition of each map output.
        job_id -- job whose input directory receives the files.

        Returns the subset of outputs_of_tasks that could not be
        obtained locally or from the network, or None if an unexpected
        error occurred.
        """
        try:
            remaining = copy.deepcopy(outputs_of_tasks)

            # first grab whatever is already stored on this node
            remaining = self.__copy_local_task_outputs(remaining, task_id, job_id)

            self.__log.debug('Remaining outputs after local copy %s', str(remaining))

            if len(remaining) > 0:
                for attempt in range(self.__brdcst_attempts):
                    self.__log.debug('Fetching outputs attempt %d', attempt + 1)

                    # ask the network who hosts the still-missing outputs
                    payload = self.__build_broadcast_packet(remaining, task_id)
                    request_msg = outputserver.PCKT_MSG % (outputserver.OUTPUT_AVAILABILITY_TAG, payload)

                    self.__log.debug('Sending message %s', request_msg)

                    # dispatch broadcast packet
                    self.__brdcster.send_message(request_msg)
                    self.__log.debug('Sent broadcast packet: %s', request_msg)

                    # give hosting nodes time to answer the broadcast
                    time.sleep(self.__wait_timeperiod)

                    # outputs successfully transferred in this round
                    fetched = []

                    for out in remaining:
                        # did any node claim to host this output?
                        host_ip = self.__output_server.return_output_host_ip(out[0], out[1], out[2], task_id)

                        self.__log.debug('Output %s hosted at %s', str((out[0], out[1], out[2], task_id)), str(host_ip))

                        if host_ip is not None:
                            self.__log.debug('Transferring file for %s', str(out))
                            if self.__transfer_file(out, task_id, job_id, host_ip):
                                fetched.append(out)
                                self.__log.debug('Transferred file for %s', str(out))
                            else:
                                self.__log.debug('Failed transferring file for %s', str(out))

                    # remove the satisfied outputs after the iteration
                    # (never mutate the list while iterating it)
                    for out in fetched:
                        remaining.remove(out)

                    # stop early once everything has been collected
                    if len(remaining) == 0:
                        break

            self.__log.debug('Input fetched for job_id %d, task_id %d with these remaining %s', job_id, task_id, str(remaining))

            # outputs which couldn't be retrieved from the network
            return remaining
        except Exception as err_msg:
            self.__log.error("Error while fetching input for task %d : %s", task_id, err_msg)


    def __build_broadcast_packet(self, outputs_of_tasks, task_id):
        """Build the broadcast payload for fetch_input: one id per
        output tuple -- job_id, task type and out-task id concatenated,
        suffixed with '_' + task_id -- joined with OUTPUT_ID_SEP.
        """
        return outputserver.OUTPUT_ID_SEP.join(
            str(out[0]) + str(out[1]) + str(out[2]) + '_' + str(task_id)
            for out in outputs_of_tasks)


    def __copy_local_task_outputs(self, outputs_of_tasks, task_id, job_id):
        """Copy outputs already stored on this node into the job's input
        directory. This covers tasks that were assigned to this node in
        the first place.

        Mutates and returns outputs_of_tasks with the locally satisfied
        tuples removed; returns None on unexpected error.
        """
        try:
            copied = []

            # destination: the job's temporary input directory
            input_dir_path = ma.const.JobsXmlData.get_filepath_str_data(ma.const.xml_local_input_temp_dir, job_id)

            for task_out in outputs_of_tasks:
                filepath = self.__nodetracker.getCompleteTaskOutputPath(task_out[0], task_out[1], task_out[2], task_id)

                # a non-None path means the output lives on this node
                if filepath is not None:
                    input_filepath = input_dir_path + os.sep + os.path.basename(filepath)
                    copied.append(task_out)
                    # copy from the output directory to the input directory
                    shutil.copyfile(filepath, input_filepath)
                    self.__log.debug('Transferred file locally for input %s to %s', filepath, input_filepath)

            # drop the satisfied outputs after the iteration is done
            for task_out in copied:
                outputs_of_tasks.remove(task_out)

            return outputs_of_tasks

        except Exception as err_msg:
            self.__log.error("Error while copying inputs locally : %s", err_msg)


    def __transfer_file(self, task_out_info, task_id, job_id, host_ip):
        """Pull one output file over TCP from host_ip into the job's
        input directory. Returns True on success, False on any failure
        (the error is logged, not propagated).
        """
        try:
            # input directory where the fetched file should land
            input_dir_path = ma.const.JobsXmlData.get_filepath_str_data(ma.const.xml_local_input_temp_dir, job_id)

            # compute the local filename the transferred data is written to
            if self.__fetch_type == InputFetcher.MRIMPROV_TYPE:
                # MR+: whole files are moved, no partitioning code involved
                if constants.MAP == task_out_info[1]:
                    input_write_filepath = input_dir_path + os.sep + ma.const.JobsXmlData.get_str_data(ma.const.xml_map_output_filename, task_out_info[0], task_out_info[2])
                else:
                    input_write_filepath = input_dir_path + os.sep + ma.const.JobsXmlData.get_str_data(ma.const.xml_reduce_output_filename, task_out_info[0], task_out_info[2])
            else:
                # classic MapReduce: the partition key (task_id) is part of
                # the filename, and reduce outputs are never shuffled
                if constants.MAP == task_out_info[1]:
                    # TODO: insert partitioner for task id
                    input_write_filepath = input_dir_path + os.sep + ma.const.JobsXmlData.get_str_data(ma.const.xml_map_output_filename_with_key, task_out_info[0], task_out_info[2], task_id)
                else:
                    self.__log.error('Trying to transfer a reduce file of job_id %d and task_id %d', task_out_info[0], task_out_info[2])
                    raise IOError('Trying to transfer a reduce file of job_id %d and task_id %d' % (task_out_info[0], task_out_info[2]))

            # request identifying the wanted output on the remote node
            out_info_msg = str(task_out_info[0]) + str(task_out_info[1]) + str(task_out_info[2]) + '_' + str(task_id)
            pckt_msg = outputserver.PCKT_MSG % (outputserver.REQUEST_FILE_TAG, out_info_msg)

            # establish TCP connection
            self.__log.debug('Establishing connection with the node hosting the file %s', task_out_info)
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            try:
                sock.connect((host_ip, self.__output_comm_port))

                self.__log.debug('Sending request for file: %s to %s', pckt_msg, str((host_ip, self.__output_comm_port)))

                # sendall retries until the whole request is on the wire
                # (the original single send() could truncate the packet)
                sock.sendall(pckt_msg.encode())

                self.__log.debug('Transferring file %s from %s', task_out_info, host_ip)

                # write in binary mode: recv() yields bytes, and the file
                # is created (and truncated) only once the connection is up
                with open(input_write_filepath, 'wb') as fd:
                    self.__log.debug('Opened file for transfer %s', input_write_filepath)
                    # recv returns b'' once the sender closes the connection
                    data = sock.recv(self.__file_transfer_buff_size)
                    while data:
                        fd.write(data)
                        data = sock.recv(self.__file_transfer_buff_size)
            finally:
                # socket is released even when the transfer fails midway
                sock.close()

            return True
        except Exception as err_msg:
            self.__log.error("Error while transferring file : %s", err_msg)
            return False


def test():
    """Smoke test: fetch four map outputs of job 1 for reduce task 1
    using dummy tracker/server implementations from outputserver.
    """
    tracker = outputserver.DummyNodeTracker()
    server = outputserver.OutputServer(tracker)
    fetcher = InputFetcher(server, tracker)
    map_outputs = [(1, 'M', out_id) for out_id in range(1, 5)]
    fetcher.fetch_input(map_outputs, 1, 1)
    